#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "../src/mstep.c"

/* One regression case for mstep(): the inputs handed to it and the values
   it is expected to produce.

   The struct tag deliberately has no leading underscore: identifiers that
   begin with '_' followed by an uppercase letter are reserved for the
   implementation (C11 7.1.3).  Only the typedef name is used elsewhere.  */
typedef struct MstepTestDatum MstepTestDatum;
struct MstepTestDatum
{
  struct
    {
      /* 4 * 5 * 5 entries; passed to mstep() as-is.  */
      float    alpha[4 * 5 * 5];
      /* Starting value; test_mstep passes a copy by pointer and compares
         the value after the call against desired_output.logbeta.  */
      float    logbeta;
      /* 5 * 5 entries.  NOTE(review): not passed to mstep() by test_mstep.  */
      float    gamma[5 * 5];
      /* Passed to mstep(); also used by test_mstep to form
         hq = -sum_k pz[k] * log (pz[k]).  */
      float    pz[5];
      float    priorz[5];
      /* Project-defined Rating records; all 8 are passed to mstep().  */
      Rating   ratings[8];
      /* NOTE(review): not read by test_mstep.  */
      float    alphastddev;
      /* Mean and standard deviation of the normal prior on logbeta that
         test_mstep builds with normal_distribution().  */
      float    logbetamean;
      float    logbetastddev;
      /* NOTE(review): gammamean/gammastddev are not read by test_mstep.  */
      float    gammamean;
      float    gammastddev;
    }                   inputs;

  struct
    {
      /* Expected return value of mstep() (checked strictly).  */
      float    newq;
      /* Expected logbeta after mstep() (checked only as a warning).  */
      float    logbeta;
    }                   desired_output;
};

/* Return nonzero when GOT and WANT agree within relative tolerance RTOL,
   scaled by (1 + |GOT| + |WANT|) so values near zero are effectively
   compared absolutely.  A NaN on either side never compares close.  */
static int
mstep_test_close (double got, double want, double rtol)
{
  return fabs (got - want) <= rtol * (1.0 + fabs (got) + fabs (want));
}

/* Run every regression case in DATA through mstep() and check the result.

   For each case: build the normal prior on logbeta from
   (logbetamean, logbetastddev), compute hq = -sum_k pz[k] * log (pz[k]),
   call mstep() starting from inputs.logbeta, then compare
     - the returned q against desired_output.newq: a mismatch prints both
       values and aborts the program;
     - the updated logbeta against desired_output.logbeta: a mismatch only
       prints a warning; this check is informational and never fails.

   The checks use explicit if/abort() rather than assert() so that they
   still run when the test is compiled with NDEBUG.  */
static void
test_mstep (void)
{
  MstepTestDatum data[] = 
    {{{{-1., 0.255605, 0.504832, 0.481769, 0.85996, 0.643395, -1.,  0.759951, 0.150263, 0.715673, 0.615078, 0.250291, -1., 0.338055,  0.131939, 0.407173, 0.453593, 0.181041, -1., 0.0278867, 0.899389,  0.416135, 0.415832, 0.881447, -1., -1., 0.513266, 0.550724,  0.916821, 0.831783, 0.257661, -1., 0.0458914, 0.435052, 0.971823,  0.614266, 0.285941, -1., 0.284789, 0.25615, 0.999188, 0.0356495,  0.946733, -1., 0.124211, 0.592015, 0.582056, 0.765693,  0.0963248, -1., -1., 0.692626, 0.165921, 0.34986, 0.214878,  0.17936, -1., 0.615197, 0.43304, 0.383095, 0.921699,  0.569306, -1., 0.997988, 0.411272, 0.307433, 0.283365,  0.713199, -1., 0.155121, 0.308245, 0.247716, 0.766466,  0.03091, -1., -1., 0.71623, 0.665659, 0.0007735, 0.934585,  0.023604, -1., 0.499738, 0.650913, 0.719707, 0.844244,  0.884541, -1., 0.217873, 0.336612, 0.922544, 0.315235,  0.219885, -1., 0.92534, 0.615111, 0.0318698, 0.506686,  0.770219, -1.},  0., {1., 0.306866, 0.784154, 0.74022, 0.739309, 0.590635, 1.,  0.118495, 0.739446, 0.804723, 0.567031, 0.618756, 1., 0.0885332,  0.0850165, 0.722788, 0.734215, 0.87066, 1., 0.748404, 0.800243,  0.41898, 0.650775, 0.823064, 1.}, {0.185132, 0.38711, 0.144089,  0.0528454, 0.878267}, {0.602956, 0.403869, 0.313537, 0.287632,  0.484462}, {{0, 2}, {1, 0}, {3, 2}, {2, 2}, {0, 4}, {3, 2}, {2,  1}, {1, 4}}, 0.664423, 0.508813, 0.7206, 0.865705,  0.57589}, {-22.700118077829739190, -0.56786873742070880392}},  {{{-1., 0.423797, 0.997813, 0.13149, 0.70523, 0.675392, -1., 0.19757,  0.71251, 0.0544548, 0.852328, 0.0124372, -1., 0.325399, 0.910366,  0.799483, 0.13417, 0.722443, -1., 0.506497, 0.485946, 0.846539,  0.237981, 0.842074, -1., -1., 0.977133, 0.125938, 0.372276,  0.266184, 0.553336, -1., 0.128125, 0.240786, 0.560955, 0.877944,  0.930555, -1., 0.528276, 0.5065, 0.025616, 0.918118,  0.202877, -1., 0.596134, 0.226133, 0.783948, 0.480434,  0.0896372, -1., -1., 0.740187, 0.93741, 0.242453, 0.247563,  0.763054, -1., 0.811472, 0.870177, 0.981379, 0.209718,  0.683346, -1., 
0.629391, 0.420425, 0.331774, 0.752791,  0.101115, -1., 0.913925, 0.306158, 0.834673, 0.898239,  0.317791, -1., -1., 0.080025, 0.0507246, 0.417805, 0.228154,  0.339838, -1., 0.113315, 0.175352, 0.98059, 0.576783,  0.301844, -1., 0.305175, 0.999211, 0.367065, 0.618497,  0.675783, -1., 0.578786, 0.0352908, 0.865706, 0.574668,  0.664862, -1.},  0., {1., 0.729132, 0.0310334, 0.676429, 0.347071, 0.649107, 1.,  0.980309, 0.258625, 0.118917, 0.30927, 0.866994, 1., 0.0832728,  0.138327, 0.732487, 0.56515, 0.778098, 1., 0.139116, 0.365422,  0.946653, 0.102315, 0.56033, 1.}, {0.330131, 0.080947, 0.527647,  0.895468, 0.600998}, {0.0499136, 0.851218, 0.548397, 0.951891,  0.0696049}, {{1, 4}, {3, 0}, {1, 3}, {0, 0}, {1, 3}, {3, 3}, {0,  2}, {1, 4}}, 0.592594, 0.42948, 0.642621, 0.202611,  0.509321}, {-31.959008660449421512, -0.50025692085735161786}},  {{{-1., 0.291152, 0.910135, 0.637461, 0.731223, 0.152036, -1.,  0.544713, 0.690808, 0.628908, 0.591707, 0.214582, -1., 0.609861,  0.10126, 0.696239, 0.613584, 0.559947, -1., 0.250042, 0.147842,  0.661693, 0.490342, 0.657449, -1., -1., 0.718363, 0.0190721,  0.287731, 0.148128, 0.42721, -1., 0.108938, 0.65027, 0.416905,  0.275174, 0.564225, -1., 0.959462, 0.787998, 0.683467, 0.349642,  0.349602, -1., 0.686737, 0.987228, 0.736058, 0.789654,  0.436695, -1., -1., 0.839386, 0.0743652, 0.299312, 0.779246,  0.121023, -1., 0.0552932, 0.011581, 0.631118, 0.693813,  0.946356, -1., 0.361311, 0.214213, 0.418639, 0.382131,  0.401848, -1., 0.426215, 0.735172, 0.0324889, 0.0522468,  0.739478, -1., -1., 0.747944, 0.296431, 0.262592, 0.302783,  0.908558, -1., 0.222065, 0.96328, 0.523536, 0.787536,  0.166772, -1., 0.951699, 0.892418, 0.0937229, 0.220416,  0.590389, -1., 0.678205, 0.675084, 0.838285, 0.18854,  0.25199, -1.},  0., {1., 0.939912, 0.805796, 0.136293, 0.512513, 0.191968, 1.,  0.509366, 0.873701, 0.20973, 0.28341, 0.287301, 1., 0.910421,  0.686194, 0.495874, 0.120529, 0.958721, 1., 0.793775, 0.402151,  0.900112, 0.368333, 0.11557, 1.}, 
{0.727067, 0.0618268, 0.179793,  0.86358, 0.787155}, {0.25603, 0.0434994, 0.351067, 0.595187,  0.746664}, {{2, 4}, {3, 4}, {2, 1}, {1, 3}, {0, 1}, {1, 1}, {1,  3}, {0, 4}}, 0.169798, 0.141337, 0.311777, 0.459364,  0.259378}, {-36.030886295812603848, -0.34874057758939136827}},  {{{-1., 0.455144, 0.815903, 0.338835, 0.300656, 0.661368, -1.,  0.413752, 0.438723, 0.932323, 0.545798, 0.686684, -1., 0.376896,  0.752531, 0.682218, 0.89953, 0.120866, -1., 0.709031, 0.331151,  0.304343, 0.374202, 0.539233, -1., -1., 0.189813, 0.992566,  0.914838, 0.279855, 0.73467, -1., 0.176663, 0.576003, 0.979199,  0.0733013, 0.762912, -1., 0.13728, 0.0468756, 0.527503, 0.0762273,  0.760383, -1., 0.294345, 0.845285, 0.176698, 0.639517,  0.585314, -1., -1., 0.514134, 0.872355, 0.265316, 0.0460808,  0.324321, -1., 0.879789, 0.350478, 0.766226, 0.589651,  0.703126, -1., 0.774475, 0.787027, 0.51635, 0.940214,  0.637196, -1., 0.740151, 0.988847, 0.863987, 0.876812,  0.445806, -1., -1., 0.143562, 0.687289, 0.237295, 0.860492,  0.629427, -1., 0.814934, 0.971979, 0.814411, 0.305107,  0.935145, -1., 0.621501, 0.0481859, 0.715456, 0.23202,  0.847026, -1., 0.261159, 0.199106, 0.291806, 0.20983,  0.521008, -1.},  0., {1., 0.21026, 0.427819, 0.333018, 0.0752026, 0.0666981, 1.,  0.74053, 0.0957229, 0.21471, 0.437271, 0.925596, 1., 0.123744,  0.400299, 0.132164, 0.990451, 0.502243, 1., 0.352113, 0.416708,  0.758431, 0.655218, 0.0909538, 1.}, {0.217602, 0.466625, 0.445388,  0.569945, 0.00734208}, {0.0388059, 0.11237, 0.494743, 0.940644,  0.298276}, {{3, 3}, {2, 1}, {0, 4}, {1, 4}, {0, 4}, {0, 4}, {1,  3}, {2, 2}}, 0.016647, 0.280032, 0.503373, 0.372679,  0.892903}, {-24.544491182208209689, -0.58629523542722536255}},  {{{-1., 0.879733, 0.371209, 0.382229, 0.39066, 0.52762, -1., 0.954501,  0.623798, 0.735442, 0.436666, 0.7369, -1., 0.157172, 0.290055,  0.866721, 0.729558, 0.118367, -1., 0.177685, 0.371978, 0.788914,  0.820091, 0.161038, -1., -1., 0.0919459, 0.28554, 0.447412,  0.268135, 0.212213, -1., 
0.914331, 0.0651829, 0.877475, 0.684593,  0.959829, -1., 0.441385, 0.142033, 0.247927, 0.22293,  0.284213, -1., 0.851978, 0.381206, 0.493372, 0.165846,  0.674293, -1., -1., 0.00922743, 0.704458, 0.345755, 0.513255,  0.917281, -1., 0.418918, 0.898344, 0.24512, 0.705069,  0.504587, -1., 0.833161, 0.367645, 0.020476, 0.544758,  0.391776, -1., 0.225613, 0.772549, 0.321828, 0.107563,  0.373635, -1., -1., 0.391344, 0.828456, 0.941717, 0.699342,  0.382116, -1., 0.123997, 0.595961, 0.186087, 0.464835,  0.705079, -1., 0.697617, 0.940967, 0.759766, 0.200492, 0.864457,  -1., 0.573322, 0.73929, 0.655734, 0.472681, 0.347709, -1.},  0., {1., 0.966741, 0.333906, 0.365118, 0.974074, 0.575397, 1.,  0.505451, 0.423402, 0.274732, 0.19328, 0.381453, 1., 0.82744,  0.0886453, 0.728445, 0.676374, 0.129823, 1., 0.147679, 0.968679,  0.475882, 0.265366, 0.574357, 1.}, {0.229389, 0.820148, 0.792685,  0.226648, 0.262648}, {0.486241, 0.427567, 0.252574, 0.687251,  0.980791}, {{2, 0}, {0, 4}, {1, 1}, {3, 1}, {0, 1}, {0, 3}, {3,  1}, {0, 0}}, 0.00416562, 0.977842, 0.493971, 0.599337,  0.176725}, {-34.495615541983055659, -0.29073678308147861474}},  {{{-1., 0.889197, 0.765525, 0.922963, 0.0469026, 0.741518, -1.,  0.796846, 0.447081, 0.781536, 0.167161, 0.567457, -1., 0.626933,  0.988851, 0.940513, 0.304809, 0.140692, -1., 0.561284, 0.687938,  0.617557, 0.159902, 0.557118, -1., -1., 0.710097, 0.123587,  0.560564, 0.380393, 0.8209, -1., 0.358061, 0.637601, 0.33349,  0.079382, 0.561215, -1., 0.19052, 0.551954, 0.912221, 0.993758,  0.563587, -1., 0.563103, 0.971709, 0.688949, 0.422895,  0.0018186, -1., -1., 0.28377, 0.0713913, 0.262993, 0.4447,  0.573674, -1., 0.947805, 0.702429, 0.064307, 0.752774,  0.589744, -1., 0.0648274, 0.730817, 0.673392, 0.0285289,  0.874307, -1., 0.178863, 0.76117, 0.0347713, 0.31072,  0.61576, -1., -1., 0.789462, 0.345823, 0.887826, 0.613941,  0.505692, -1., 0.274431, 0.624833, 0.169241, 0.932018,  0.326626, -1., 0.922404, 0.104934, 0.179244, 0.736883,  0.857576, -1., 
0.374118, 0.505853, 0.708354, 0.983269,  0.195255, -1.},  0., {1., 0.744683, 0.673583, 0.672549, 0.579495, 0.955221, 1.,  0.32776, 0.784723, 0.965554, 0.449529, 0.0533287, 1., 0.15989,  0.796313, 0.517511, 0.726702, 0.237487, 1., 0.691379, 0.338267,  0.989819, 0.37991, 0.317261, 1.}, {0.832414, 0.281466, 0.396641,  0.122006, 0.0877311}, {0.607883, 0.724092, 0.542511, 0.13251,  0.280123}, {{2, 2}, {2, 2}, {0, 4}, {3, 1}, {2, 2}, {2, 3}, {1,  4}, {2, 1}}, 0.939369, 0.576957, 0.682981, 0.226794,  0.779478}, {-23.921711901388196485, -0.53663744666777733553}},  {{{-1., 0.780644, 0.16547, 0.500092, 0.541992, 0.0892658, -1.,  0.827204, 0.510273, 0.162082, 0.772005, 0.99479, -1., 0.228807,  0.765441, 0.649999, 0.907059, 0.620924, -1., 0.0413489, 0.107488,  0.774549, 0.340801, 0.10198, -1., -1., 0.530531, 0.0915674,  0.114006, 0.322502, 0.749887, -1., 0.926097, 0.613914, 0.78051,  0.660621, 0.0988936, -1., 0.103641, 0.618428, 0.888616, 0.104103,  0.874834, -1., 0.852987, 0.238617, 0.197044, 0.25391,  0.811638, -1., -1., 0.131129, 0.422496, 0.913109, 0.709658,  0.600598, -1., 0.330928, 0.799103, 0.387157, 0.850711,  0.404831, -1., 0.185189, 0.606647, 0.19009, 0.305938,  0.0815474, -1., 0.988219, 0.301474, 0.201834, 0.206713,  0.135232, -1., -1., 0.0628566, 0.00478959, 0.952803, 0.323593,  0.931727, -1., 0.582294, 0.0396939, 0.613935, 0.331129,  0.251365, -1., 0.240591, 0.226778, 0.480418, 0.846534,  0.0554023, -1., 0.620131, 0.290328, 0.540597, 0.973855,  0.631912, -1.},  0., {1., 0.988854, 0.338763, 0.767142, 0.496681, 0.925997, 1.,  0.333973, 0.814339, 0.173088, 0.99427, 0.751679, 1., 0.774645,  0.559153, 0.663141, 0.500314, 0.534054, 1., 0.332375, 0.182723,  0.65378, 0.478651, 0.712244, 1.}, {0.892395, 0.113183, 0.504796,  0.0803315, 0.903541}, {0.77442, 0.737655, 0.583651, 0.977544,  0.440447}, {{2, 1}, {0, 2}, {1, 3}, {2, 2}, {3, 4}, {2, 3}, {2,  2}, {1, 0}}, 0.923316, 0.410563, 0.983274, 0.688768,  0.148672}, {-33.370292902876612242, -0.84927947683580930927}},  {{{-1., 
0.85141, 0.320134, 0.188454, 0.614618, 0.519035, -1.,  0.137411, 0.534675, 0.135967, 0.806791, 0.245016, -1., 0.421492,  0.63117, 0.72646, 0.341475, 0.647071, -1., 0.893516, 0.142809,  0.36393, 0.206624, 0.970199, -1., -1., 0.732246, 0.380656,  0.517856, 0.821528, 0.880836, -1., 0.0605226, 0.329402, 0.20691,  0.361801, 0.923112, -1., 0.794727, 0.0709432, 0.555009, 0.678096,  0.373236, -1., 0.439773, 0.828549, 0.336621, 0.726164,  0.546257, -1., -1., 0.68574, 0.972691, 0.51954, 0.576058,  0.953494, -1., 0.592035, 0.00168387, 0.75453, 0.0726582,  0.531512, -1., 0.672282, 0.54762, 0.710858, 0.6084, 0.877555, -1.,  0.476677, 0.155848, 0.930304, 0.504319, 0.0369039, -1., -1.,  0.327299, 0.593683, 0.778155, 0.490647, 0.641559, -1., 0.620992,  0.258615, 0.914589, 0.688064, 0.0289579, -1., 0.256931, 0.160059,  0.615406, 0.497446, 0.584649, -1., 0.612439, 0.904549, 0.889046,  0.707094, 0.135762, -1.},  0., {1., 0.748701, 0.958741, 0.202775, 0.098858, 0.421402, 1.,  0.365058, 0.42462, 0.608211, 0.779843, 0.744066, 1., 0.166006,  0.693623, 0.0917786, 0.715108, 0.909074, 1., 0.533564, 0.476372,  0.217662, 0.324425, 0.921125, 1.}, {0.571824, 0.328616, 0.617331,  0.785363, 0.823123}, {0.369875, 0.414555, 0.686505, 0.401721,  0.00481646}, {{2, 3}, {2, 1}, {2, 0}, {0, 1}, {1, 4}, {2, 1}, {3,  3}, {1, 1}}, 0.989935, 0.0782937, 0.621878, 0.260751,  0.823929}, {-41.692640393376944315, -0.84842443732286479987}},  {{{-1., 0.384671, 0.5301, 0.545643, 0.914855, 0.851107, -1.,  0.0537273, 0.327981, 0.59043, 0.929982, 0.481904, -1., 0.999365,  0.973099, 0.144619, 0.658781, 0.62949, -1., 0.558544, 0.458114,  0.257059, 0.624674, 0.568609, -1., -1., 0.37982, 0.635181,  0.363923, 0.744679, 0.995149, -1., 0.105081, 0.81828, 0.829824,  0.144042, 0.0513539, -1., 0.490299, 0.239395, 0.21406, 0.56945,  0.490935, -1., 0.266296, 0.0694407, 0.91067, 0.861445,  0.707752, -1., -1., 0.611327, 0.65361, 0.236771, 0.139144,  0.231507, -1., 0.0184292, 0.872848, 0.394464, 0.236358,  0.913348, -1., 0.0545678, 
0.56464, 0.0923161, 0.861994,  0.564268, -1., 0.325245, 0.878256, 0.292544, 0.0733338,  0.0589496, -1., -1., 0.808816, 0.381874, 0.211889, 0.351197,  0.197489, -1., 0.728264, 0.975118, 0.212054, 0.965982,  0.709835, -1., 0.10227, 0.81759, 0.729624, 0.796487,  0.0477026, -1., 0.25295, 0.637308, 0.934493, 0.483434,  0.927704, -1.},  0., {1., 0.759052, 0.641949, 0.4101, 0.868755, 0.950236, 1.,  0.260075, 0.198211, 0.517557, 0.752747, 0.531811, 1., 0.223093,  0.305503, 0.786765, 0.821976, 0.120822, 1., 0.487914, 0.057141,  0.0254887, 0.0731198, 0.234964, 1.}, {0.419833, 0.0909959,  0.589686, 0.30726, 0.660781}, {0.449047, 0.179585, 0.438505,  0.710546,  0.188972}, {{1, 3}, {2, 3}, {0, 4}, {1, 3}, {2, 1}, {2, 2}, {1,  2}, {2, 1}}, 0.981374, 0.920948, 0.957799, 0.657162,  0.758282}, {-29.112381968992402432, -0.89299421548952279337}},  {{{-1., 0.615445, 0.171034, 0.835186, 0.637459, 0.127531, -1.,  0.113893, 0.809697, 0.564339, 0.892567, 0.694059, -1., 0.718701,  0.974654, 0.585307, 0.033278, 0.269655, -1., 0.795068, 0.146802,  0.322732, 0.0806824, 0.813694, -1., -1., 0.225854, 0.364934,  0.423521, 0.0554126, 0.610409, -1., 0.1939, 0.588335, 0.417953,  0.482878, 0.0800077, -1., 0.778637, 0.853614, 0.590311, 0.385948,  0.0599358, -1., 0.87896, 0.00500353, 0.35267, 0.790281,  0.0838921, -1., -1., 0.858201, 0.0299378, 0.709599, 0.270198,  0.632348, -1., 0.665004, 0.286078, 0.214785, 0.0219384,  0.471104, -1., 0.697744, 0.796832, 0.53906, 0.391096,  0.919106, -1., 0.943218, 0.94875, 0.00514785, 0.859171,  0.0642573, -1., -1., 0.943746, 0.652478, 0.0688893, 0.980365,  0.0855446, -1., 0.62254, 0.35929, 0.710167, 0.453197,  0.957536, -1., 0.0732122, 0.495382, 0.431259, 0.486432,  0.375469, -1., 0.69855, 0.892198, 0.0953358, 0.456362,  0.755332, -1.},  0., {1., 0.943449, 0.090188, 0.597192, 0.691075, 0.999702, 1.,  0.43771, 0.528302, 0.71071, 0.914158, 0.815171, 1., 0.169012,  0.000542493, 0.460961, 0.857635, 0.0957999, 1., 0.505161,  0.029702, 0.371203, 0.720331, 0.806611, 1.}, 
{0.137504, 0.275867,  0.263969, 0.0512784, 0.194055}, {0.185679, 0.666777, 0.360203,  0.194353,  0.747969}, {{0, 0}, {1, 1}, {3, 2}, {3, 1}, {1, 4}, {3, 4}, {3,  1}, {0, 2}}, 0.138475, 0.649494, 0.280195, 0.932798,  0.969463}, {-13.287922545989781527, 0.23596937832915911858}}}; 

  /* Derive counts from the table layout instead of repeating magic numbers.
     nlabels is the length of pz (5); it is what the original passed as
     mstep()'s second argument -- presumably the number of labels, TODO
     confirm against mstep's prototype.  */
  const size_t ndata = sizeof data / sizeof data[0];
  const size_t nlabels =
    sizeof data[0].inputs.pz / sizeof data[0].inputs.pz[0];
  const size_t nratings =
    sizeof data[0].inputs.ratings / sizeof data[0].inputs.ratings[0];
  const double rtol = 1e-3;

  for (size_t i = 0; i < ndata; ++i)
    {
      NormalDistribution prior_logbeta = 
        normal_distribution (data[i].inputs.logbetamean, 
                             data[i].inputs.logbetastddev);

      /* hq = -sum_k pz[k] * log (pz[k]), with the convention 0 * log 0 = 0
         so a zero entry contributes nothing instead of making hq NaN.  */
      float hq = 0;
      for (size_t k = 0; k < nlabels; ++k)
        {
          float p = data[i].inputs.pz[k];
          if (p > 0)
            {
              hq -= p * log (p);
            }
        }

      /* mstep() receives a pointer to a local copy of logbeta; its value
         after the call is compared with desired_output.logbeta below.  */
      float logbeta = data[i].inputs.logbeta;
      float q = mstep (data[i].inputs.alpha,
                       nlabels,
                       &logbeta,
                       data[i].inputs.pz,
                       data[i].inputs.priorz,
                       data[i].inputs.ratings,
                       nratings,
                       &prior_logbeta.base,
                       /* NOTE(review): literal 0 kept from the original; if
                          this parameter is a pointer, NULL would read more
                          clearly -- confirm against mstep's prototype.  */
                       0,
                       hq,
                       1000,
                       1e-7);

      /* Hard check: explicit abort() so it is not compiled out by NDEBUG.  */
      if (! mstep_test_close (q, data[i].desired_output.newq, rtol))
        {
          fprintf (stderr, "%g ?= %g\n", q, data[i].desired_output.newq);
          fprintf (stderr, "test_mstep: datum %zu: newq mismatch\n", i);
          abort ();
        }

      /* Soft check: a logbeta mismatch is reported but tolerated.  */
      if (! mstep_test_close (logbeta, data[i].desired_output.logbeta, rtol))
        {
          fprintf (stderr, "(oops ... but this is ok) %g ?= %g\n", logbeta, data[i].desired_output.logbeta);
        }
    }
}

/* Test driver: seed the drand48 family of generators with a fixed value so
   that any randomness drawn inside mstep (not visible here -- presumably via
   drand48/lrand48) is reproducible, then run the mstep regression cases.
   Exits with status 0 whenever test_mstep returns normally.  */
int
main (void)
{
  const long seed = 69;   /* fixed seed: identical results on every run */

  srand48 (seed);
  test_mstep ();

  return 0;
}
